package cn.jiangzeyin.controller.classifiers;

import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;
import cn.jiangzeyin.common.DefaultSystemLog;
import cn.jiangzeyin.common.JsonMessage;
import com.hankcs.hanlp.classification.classifiers.IClassifier;
import com.hankcs.hanlp.classification.classifiers.NaiveBayesClassifier;
import com.hankcs.hanlp.classification.models.AbstractModel;
import com.hankcs.hanlp.corpus.io.IOUtil;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;

import java.io.File;
import java.io.IOException;
import java.nio.file.InvalidPathException;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Created by jiangzeyin on 2017/12/9.
 */
@RestController
@RequestMapping("classifiers")
public class ClassificationTrain extends BaseClassification {
    private static String modelFolder;

    @RequestMapping(value = "train.json", method = RequestMethod.POST, produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
    public String train(String saveModelName) {
        if (StrUtil.isEmpty(saveModelName)) {
            return JsonMessage.getString(400, "saveModelName is null");
        }
        // 防止传入当前文件夹以上的目录
        saveModelName = saveModelName.replace("..", "");
        // 分类词库路径
        String tip = doCorpusFolder();
        if (tip != null) {
            return tip;
        }
        File corpusFolderFile = new File(corpusFolder);
        String[] lists = corpusFolderFile.list();
        if (lists == null || lists.length < 1) {
            return JsonMessage.getString(401, "corpusFolder not found data");
        }
        // 模型文件存放路径
        if (StrUtil.isEmpty(modelFolder)) {
            String[] result = getModelFolder();
            if (result[0] == null) {
                modelFolder = result[1];
            } else {
                return result[0];
            }
        }
        // 创建分类器，更高级的功能请参考IClassifier的接口定义
        IClassifier classifier = new NaiveBayesClassifier();
        try {
            classifier.train(corpusFolder);
        } catch (IOException e) {
            DefaultSystemLog.ERROR().error("训练模型异常", e);
            return JsonMessage.getString(505, "train error", e.getMessage());
        }
        // 训练后的模型支持持久化，下次就不必训练了
        AbstractModel abstractModel = classifier.getModel();
        String[] catalog = abstractModel.catalog;
        if (catalog == null || catalog.length < 1) {
            return JsonMessage.getString(404, "not found train data");
        }
        String path = modelFolder + "/" + saveModelName;
        path = FileUtil.normalize(path);
        IOUtil.saveObjectTo(abstractModel, path);
        return JsonMessage.getString(200, "success");
    }
}
